{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.0"
    },
    "colab": {
      "name": "BERT-S-train.ipynb",
      "provenance": []
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "W69YhH6CgwQP",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# coding: UTF-8\n",
        "%matplotlib inline\n",
        "import torch\n",
        "import time \n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F \n",
        "from pytorch_pretrained_bert import BertModel, BertTokenizer, BertConfig, BertAdam\n",
        "import pandas as pd \n",
        "import numpy as np \n",
        "from tqdm import tqdm \n",
        "from torch.utils.data import *\n",
        "import matplotlib.pyplot as plt"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DPV2ldNcgwQY",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def show_plot(iteration,loss):\n",
        "    plt.plot(iteration,loss)\n",
        "    plt.show()\n",
        "    \n",
        "path='data'\n",
        "tokenizer = BertTokenizer(\"data/vocab.txt\")"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4lysaIxhgwQe",
        "colab_type": "code",
        "colab": {},
        "outputId": "f6876a60-1b00-4629-97c5-0e9dc995ad8d"
      },
      "source": [
        "input_ids = []\n",
        "input_types = []\n",
        "input_masks = []\n",
        "label = []\n",
        "pad_size = 64\n",
        " \n",
        "with open(\"data/data.txt\") as f:\n",
        "    for i, l in tqdm(enumerate(f)): \n",
        "        x1, x2, y = l.strip().split('\\t')\n",
        "        x1 = tokenizer.tokenize(x1)\n",
        "        x2 = tokenizer.tokenize(x2)\n",
        "        tokens = [\"[CLS]\"] + x1 + [\"[SEP]\"] + x2 +[\"[SEP]\"]\n",
        "        ids = tokenizer.convert_tokens_to_ids(tokens)\n",
        "        types = [0] *(len(ids))\n",
        "        masks = [1] * len(ids)\n",
        "\n",
        "        if len(ids) < pad_size:\n",
        "            types = types + [1] * (pad_size - len(ids))\n",
        "            masks = masks + [0] * (pad_size - len(ids))\n",
        "            ids = ids + [0] * (pad_size - len(ids))\n",
        "        else:\n",
        "            types = types[:pad_size]\n",
        "            masks = masks[:pad_size]\n",
        "            ids = ids[:pad_size]\n",
        "        input_ids.append(ids)\n",
        "        input_types.append(types)\n",
        "        input_masks.append(masks)\n",
        "        label.append([int(y)])"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "8398it [00:02, 2988.38it/s]\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xNnJOTNVgwQl",
        "colab_type": "code",
        "colab": {},
        "outputId": "ecb29dc6-40ae-43dd-b261-fa0a276b04c4"
      },
      "source": [
        "random_order = list(range(len(input_ids)))\n",
        "np.random.seed(2020)\n",
        "np.random.shuffle(random_order)\n",
        "print(random_order[:10])\n",
        "\n",
        "input_ids_train = np.array([input_ids[i] for i in random_order[:int(len(input_ids)*0.8)]])\n",
        "input_types_train = np.array([input_types[i] for i in random_order[:int(len(input_ids)*0.8)]])\n",
        "input_masks_train = np.array([input_masks[i] for i in random_order[:int(len(input_ids)*0.8)]])\n",
        "y_train = np.array([label[i] for i in random_order[:int(len(input_ids) * 0.8)]])\n",
        "print(input_ids_train.shape, input_types_train.shape, input_masks_train.shape, y_train.shape)\n",
        "\n",
        "input_ids_test = np.array([input_ids[i] for i in random_order[int(len(input_ids)*0.8):]])\n",
        "input_types_test = np.array([input_types[i] for i in random_order[int(len(input_ids)*0.8):]])\n",
        "input_masks_test = np.array([input_masks[i] for i in random_order[int(len(input_ids)*0.8):]])\n",
        "y_test = np.array([label[i] for i in random_order[int(len(input_ids) * 0.8):]])\n",
        "print(input_ids_test.shape, input_types_test.shape, input_masks_test.shape, y_test.shape)"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "[4602, 7440, 4762, 7897, 2306, 6203, 3410, 974, 1740, 4260]\n",
            "(6718, 64) (6718, 64) (6718, 64) (6718, 1)\n",
            "(1680, 64) (1680, 64) (1680, 64) (1680, 1)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "g3zozLA5gwQu",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "BATCH_SIZE = 32\n",
        "train_data = TensorDataset(torch.LongTensor(input_ids_train), \n",
        "                           torch.LongTensor(input_types_train), \n",
        "                           torch.LongTensor(input_masks_train), \n",
        "                           torch.LongTensor(y_train))\n",
        "train_sampler = RandomSampler(train_data)  \n",
        "train_loader = DataLoader(train_data, sampler=train_sampler, batch_size=BATCH_SIZE)\n",
        "\n",
        "test_data = TensorDataset(torch.LongTensor(input_ids_test), \n",
        "                          torch.LongTensor(input_types_test), \n",
        "                         torch.LongTensor(input_masks_test),\n",
        "                          torch.LongTensor(y_test))\n",
        "test_sampler = SequentialSampler(test_data)\n",
        "test_loader = DataLoader(test_data, sampler=test_sampler, batch_size=BATCH_SIZE)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LnmOHgcHgwQz",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class Model(nn.Module):\n",
        "    def __init__(self):\n",
        "        super(Model, self).__init__()\n",
        "        self.bert = BertModel.from_pretrained(path)\n",
        "        for param in self.bert.parameters():\n",
        "            param.requires_grad = True\n",
        "        self.fc = nn.Linear(768, 2)\n",
        "\n",
        "    def forward(self, x):\n",
        "        context = x[0]\n",
        "        types = x[1]\n",
        "        mask = x[2]\n",
        "        _, pooled = self.bert(context, token_type_ids=types, \n",
        "                              attention_mask=mask, \n",
        "                              output_all_encoded_layers=False)\n",
        "        out = self.fc(pooled)\n",
        "        return out"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "93mle9ZzgwQ6",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "model = Model().to(DEVICE) "
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sFpmL2nCgwRD",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "param_optimizer = list(model.named_parameters())\n",
        "no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']\n",
        "optimizer_grouped_parameters = [\n",
        "    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},\n",
        "    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]\n",
        "NUM_EPOCHS = 1\n",
        "optimizer = BertAdam(optimizer_grouped_parameters,\n",
        "                     lr=2e-5,\n",
        "                     warmup=0.05,\n",
        "                     t_total=len(train_loader) * NUM_EPOCHS\n",
        "                    )"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EIHFZTIPgwRL",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "print('TRAINNING START!')\n",
        "for epoch in range(NUM_EPOCHS):\n",
        "    correct = 0\n",
        "    total = 0\n",
        "    cost = 0\n",
        "    count = []\n",
        "    loss_his = []\n",
        "    acc_his = []\n",
        "    for i,(x1,x2,x3,y) in enumerate(train_loader):\n",
        "        x1,x2,x3,y = x1.to(device), x2.to(device), x3.to(device), y.to(device)\n",
        "        y_pred = model([x1, x2, x3])\n",
        "        model.zero_grad()\n",
        "        loss = F.cross_entropy(y_pred, y.squeeze())\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "        pred = y_pred.max(-1, keepdim=True)[1]   # .max(): 2输出，分别为最大值和最大值的index\n",
        "        acc += pred.eq(y.view_as(pred)).sum().item() \n",
        "        if(i + 1) % 100 == 0:\n",
        "            print(' Train Epoch: {} [{}/{} ({:.2f}%)]\\n Loss: {:.6f} \\n Accuracy: {}'.format(epoch,\n",
        "                                                                           (i+1) * len(x1), \n",
        "                                                                           len(train_loader.dataset),\n",
        "                                                                           100. * i / len(train_loader), \n",
        "                                                                           loss.item(),\n",
        "                                                                           100. * acc / len(train_loader)))\n",
        "        count.append(100. * i / len(train_loader))\n",
        "        loss_his.append(loss.item())\n",
        "        acc_his.append(100. * acc / len(train_loader))\n",
        "    PATH_checkpoint = './data/1.pth'\n",
        "    torch.save(net.state_dict(),PATH_checkpoint)\n",
        "print(\"TRAINNING END\")\n",
        "show_plot(counter,loss_his)\n",
        "show_plot(counter,acc_his)"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}